import numpy as np
import torch
from torch.distributions import one_hot_categorical
import time


class RolloutWorker:
    """Roll out one episode at a time and package it as padded arrays.

    Per step the environment is queried once for the observation, state,
    ``p_state`` and the action-availability mask, and
    ``agents.choose_action`` is called once to obtain the action(s) that are
    passed to ``env.step``.
    """

    def __init__(self, env, agents, args):
        """Store the environment, the agents and the rollout settings.

        Args:
            env: environment providing ``reset``, ``step``, ``get_obs``,
                ``get_state``, ``get_p_state``, ``get_avail_agent_actions``,
                ``save_replay`` and ``close``.
            agents: object providing ``choose_action`` and a ``policy`` with
                ``init_hidden`` (and ``z_policy`` when ``args.alg`` is
                ``'maven'``).
            args: settings namespace; the attributes copied below plus
                ``replay_dir``, ``epsilon_anneal_scale``, ``alg``, ``cuda``
                and ``evaluate_epoch`` are read by this class.
        """
        self.env = env
        self.agents = agents
        self.episode_limit = args.episode_limit
        self.n_actions = args.n_actions
        self.n_agents = args.n_agents
        self.state_shape = args.state_shape
        self.obs_shape = args.obs_shape
        self.args = args

        self.epsilon = args.epsilon
        self.anneal_epsilon = args.anneal_epsilon
        self.min_epsilon = args.min_epsilon
        print('Init RolloutWorker')

    def _anneal(self, epsilon):
        """Return ``epsilon`` lowered by one anneal step while above the floor."""
        if epsilon > self.min_epsilon:
            return epsilon - self.anneal_epsilon
        return epsilon

    def _sample_maven_z(self):
        """Sample a one-hot latent from ``agents.policy.z_policy`` at the current state."""
        state = torch.tensor(self.env.get_state(), dtype=torch.float32)
        if self.args.cuda:
            state = state.cuda()
        z_prob = self.agents.policy.z_policy(state)
        maven_z = one_hot_categorical.OneHotCategorical(z_prob).sample()
        return list(maven_z.cpu())

    def generate_episode(self, episode_num=None, evaluate=False):
        """Run one episode and return it padded to ``episode_limit`` steps.

        Args:
            episode_num: index of the evaluation episode; only used to decide
                when to close the environment / save a replay.
            evaluate: when true, epsilon is 0 for this episode and
                ``self.epsilon`` is left untouched.

        Returns:
            ``(episode, episode_reward, win_tag, step)`` where ``episode`` is
            a dict of arrays, each with a leading axis of length 1, with keys
            ``o, s, u, r, avail_u, o_next, s_next, avail_u_next, u_onehot,
            padded, terminated, p_state`` (plus ``z`` for MAVEN);
            ``episode_reward`` is the summed reward, ``win_tag`` is the win
            flag of the last executed step (``False`` if no step ran) and
            ``step`` is the number of executed steps.
        """
        args = self.args
        # prepare for save replay of evaluation
        if args.replay_dir != '' and evaluate and episode_num == 0:
            self.env.close()

        o, u, r, s, avail_u, u_onehot, p_s = [], [], [], [], [], [], []
        self.env.reset()
        terminated = False
        win_tag = False
        step = 0
        episode_reward = 0  # cumulative rewards
        last_action = np.zeros((args.n_agents, args.n_actions))
        self.agents.policy.init_hidden(1)
        onehot_table = np.eye(self.n_actions)  # row i is the one-hot code of action i
        # choose_action is called once per step for all agents, always with id 0.
        # Defined up front: the MAVEN branch used to read it before assignment.
        agent_id = 0

        # epsilon
        epsilon = 0 if evaluate else self.epsilon
        if args.epsilon_anneal_scale == 'episode':
            epsilon = self._anneal(epsilon)

        # sample z for maven
        maven_z = None
        if args.alg == 'maven':
            maven_z = self._sample_maven_z()

        while not terminated and step < self.episode_limit:
            obs = self.env.get_obs()
            state = self.env.get_state()
            p_state = self.env.get_p_state()
            avail_action = self.env.get_avail_agent_actions(id=0)

            if args.alg == 'maven':
                # Same joint-call form as the non-MAVEN branch (and as
                # mavenRolloutWorker), with the sampled latent added.
                action = self.agents.choose_action(obs, last_action, agent_id,
                                                   avail_action, epsilon, maven_z, evaluate)
            else:
                action = self.agents.choose_action(obs, last_action, agent_id,
                                                   avail_action, epsilon, evaluate)
            actions = np.array(action.to('cpu'))
            actions_onehot = onehot_table[actions]
            last_action = actions_onehot

            reward, terminated, info = self.env.step(actions)
            # A step counts as a win when its reward exceeds 8 or the
            # environment reports 'battle_won' on termination.
            win_tag = 1 if (reward > 8) or ('battle_won' in info and terminated and info['battle_won']) else 0

            o.append(obs)
            s.append(state)
            p_s.append(p_state)
            u.append(np.reshape(actions, [self.n_agents, 1]))
            u_onehot.append(actions_onehot)
            avail_u.append(avail_action)
            r.append([reward])
            episode_reward += reward
            step += 1
            if args.epsilon_anneal_scale == 'step':
                epsilon = self._anneal(epsilon)

        # last obs
        obs = self.env.get_obs()
        state = self.env.get_state()
        p_state = self.env.get_p_state()
        o.append(obs)
        s.append(state)
        o_next, s_next = o[1:], s[1:]
        o, s = o[:-1], s[:-1]
        # get avail_action for the last obs, because target_q needs it in
        # training; query the environment again instead of reusing the mask
        # that belonged to the previous state.
        avail_u.append(self.env.get_avail_agent_actions(id=0))
        avail_u_next = avail_u[1:]
        avail_u = avail_u[:-1]

        # Pad every sequence up to episode_limit entries.
        remain_length = self.episode_limit - step
        obs_pad = [np.zeros((self.n_agents, self.obs_shape))] * remain_length
        state_pad = [np.zeros((self.state_shape))] * remain_length
        avail_pad = [np.ones((self.n_agents, self.n_actions))] * remain_length
        o = o + obs_pad
        o_next = o_next + obs_pad
        u = u + [np.zeros((self.n_agents, 1))] * remain_length
        u_onehot = u_onehot + [np.zeros((self.n_agents, self.n_actions))] * remain_length
        s = s + state_pad
        s_next = s_next + state_pad
        r = r + [[0]] * remain_length
        avail_u = avail_u + avail_pad
        avail_u_next = avail_u_next + avail_pad
        # Both masks are 0 for executed steps and 1 for padding.
        padded = [[0.]] * step + [[1.]] * remain_length
        terminate = [[0.]] * step + [[1.]] * remain_length
        # Padding repeats the p_state read after the last step.
        p_s = p_s + [p_state] * remain_length

        episode = dict(o=o,
                       s=s,
                       u=u,
                       r=r,
                       avail_u=avail_u,
                       o_next=o_next,
                       s_next=s_next,
                       avail_u_next=avail_u_next,
                       u_onehot=u_onehot,
                       padded=padded,
                       terminated=terminate,
                       p_state=p_s
                       )
        # add episode dim
        for key in episode.keys():
            episode[key] = np.array([episode[key]])
        if not evaluate:
            self.epsilon = epsilon
        if args.alg == 'maven':
            episode['z'] = np.array([maven_z.copy()])
        if evaluate and episode_num == args.evaluate_epoch - 1 and args.replay_dir != '':
            self.env.save_replay()
            self.env.close()
        return episode, episode_reward, win_tag, step

class mavenRolloutWorker:
    """Roll out one episode at a time and package it as padded arrays.

    Per step the environment is queried once for the observation, state,
    ``p_state`` and the action-availability mask, and
    ``agents.choose_action`` is called once (with the sampled latent ``z``
    when ``args.alg`` is ``'maven'``) to obtain the action(s) passed to
    ``env.step``.
    """

    def __init__(self, env, agents, args):
        """Store the environment, the agents and the rollout settings.

        Args:
            env: environment providing ``reset``, ``step``, ``get_obs``,
                ``get_state``, ``get_p_state``, ``get_avail_agent_actions``,
                ``save_replay`` and ``close``.
            agents: object providing ``choose_action`` and a ``policy`` with
                ``init_hidden`` (and ``z_policy`` when ``args.alg`` is
                ``'maven'``).
            args: settings namespace; the attributes copied below plus
                ``replay_dir``, ``epsilon_anneal_scale``, ``alg``, ``cuda``
                and ``evaluate_epoch`` are read by this class.
        """
        self.env = env
        self.agents = agents
        self.episode_limit = args.episode_limit
        self.n_actions = args.n_actions
        self.n_agents = args.n_agents
        self.state_shape = args.state_shape
        self.obs_shape = args.obs_shape
        self.args = args

        self.epsilon = args.epsilon
        self.anneal_epsilon = args.anneal_epsilon
        self.min_epsilon = args.min_epsilon
        print('Init RolloutWorker')

    def _anneal(self, epsilon):
        """Return ``epsilon`` lowered by one anneal step while above the floor."""
        if epsilon > self.min_epsilon:
            return epsilon - self.anneal_epsilon
        return epsilon

    def _sample_maven_z(self):
        """Sample a one-hot latent from ``agents.policy.z_policy`` at the current state."""
        state = torch.tensor(self.env.get_state(), dtype=torch.float32)
        if self.args.cuda:
            state = state.cuda()
        z_prob = self.agents.policy.z_policy(state)
        maven_z = one_hot_categorical.OneHotCategorical(z_prob).sample()
        return list(maven_z.cpu())

    def generate_episode(self, episode_num=None, evaluate=False):
        """Run one episode and return it padded to ``episode_limit`` steps.

        Args:
            episode_num: index of the evaluation episode; only used to decide
                when to close the environment / save a replay.
            evaluate: when true, epsilon is 0 for this episode and
                ``self.epsilon`` is left untouched.

        Returns:
            ``(episode, episode_reward, win_tag, step)`` where ``episode`` is
            a dict of arrays, each with a leading axis of length 1, with keys
            ``o, s, u, r, avail_u, o_next, s_next, avail_u_next, u_onehot,
            padded, terminated, p_state`` (plus ``z`` for MAVEN);
            ``episode_reward`` is the summed reward, ``win_tag`` is the win
            flag of the last executed step (``False`` if no step ran) and
            ``step`` is the number of executed steps.
        """
        args = self.args
        # prepare for save replay of evaluation
        if args.replay_dir != '' and evaluate and episode_num == 0:
            self.env.close()

        o, u, r, s, avail_u, u_onehot, p_s = [], [], [], [], [], [], []
        self.env.reset()
        terminated = False
        win_tag = False
        step = 0
        episode_reward = 0  # cumulative rewards
        last_action = np.zeros((args.n_agents, args.n_actions))
        self.agents.policy.init_hidden(1)
        onehot_table = np.eye(self.n_actions)  # row i is the one-hot code of action i
        # choose_action is called once per step for all agents, always with id 0.
        agent_id = 0

        # epsilon
        epsilon = 0 if evaluate else self.epsilon
        if args.epsilon_anneal_scale == 'episode':
            epsilon = self._anneal(epsilon)

        # sample z for maven
        maven_z = None
        if args.alg == 'maven':
            maven_z = self._sample_maven_z()

        while not terminated and step < self.episode_limit:
            obs = self.env.get_obs()
            state = self.env.get_state()
            p_state = self.env.get_p_state()
            avail_action = self.env.get_avail_agent_actions(id=0)

            if args.alg == 'maven':
                action = self.agents.choose_action(obs, last_action, agent_id,
                                                   avail_action, epsilon, maven_z, evaluate)
            else:
                action = self.agents.choose_action(obs, last_action, agent_id,
                                                   avail_action, epsilon, evaluate)
            actions = np.array(action.to('cpu'))
            actions_onehot = onehot_table[actions]
            last_action = actions_onehot

            reward, terminated, info = self.env.step(actions)
            # A step counts as a win when its reward exceeds 50 or the
            # environment reports 'battle_won' on termination. The rule used
            # to be duplicated in two identical branches keyed on args.env.
            win_tag = 1 if (reward > 50) or ('battle_won' in info and terminated and info['battle_won']) else 0

            o.append(obs)
            s.append(state)
            p_s.append(p_state)
            u.append(np.reshape(actions, [self.n_agents, 1]))
            u_onehot.append(actions_onehot)
            avail_u.append(avail_action)
            r.append([reward])
            episode_reward += reward
            step += 1
            if args.epsilon_anneal_scale == 'step':
                epsilon = self._anneal(epsilon)

        # last obs
        obs = self.env.get_obs()
        state = self.env.get_state()
        p_state = self.env.get_p_state()
        o.append(obs)
        s.append(state)
        o_next, s_next = o[1:], s[1:]
        o, s = o[:-1], s[:-1]
        # get avail_action for the last obs, because target_q needs it in
        # training; query the environment again instead of reusing the mask
        # that belonged to the previous state.
        avail_u.append(self.env.get_avail_agent_actions(id=0))
        avail_u_next = avail_u[1:]
        avail_u = avail_u[:-1]

        # Pad every sequence up to episode_limit entries.
        remain_length = self.episode_limit - step
        obs_pad = [np.zeros((self.n_agents, self.obs_shape))] * remain_length
        state_pad = [np.zeros((self.state_shape))] * remain_length
        avail_pad = [np.ones((self.n_agents, self.n_actions))] * remain_length
        o = o + obs_pad
        o_next = o_next + obs_pad
        u = u + [np.zeros((self.n_agents, 1))] * remain_length
        u_onehot = u_onehot + [np.zeros((self.n_agents, self.n_actions))] * remain_length
        s = s + state_pad
        s_next = s_next + state_pad
        r = r + [[0]] * remain_length
        avail_u = avail_u + avail_pad
        avail_u_next = avail_u_next + avail_pad
        # Both masks are 0 for executed steps and 1 for padding.
        padded = [[0.]] * step + [[1.]] * remain_length
        terminate = [[0.]] * step + [[1.]] * remain_length
        # Padding repeats the p_state read after the last step.
        p_s = p_s + [p_state] * remain_length

        episode = dict(o=o,
                       s=s,
                       u=u,
                       r=r,
                       avail_u=avail_u,
                       o_next=o_next,
                       s_next=s_next,
                       avail_u_next=avail_u_next,
                       u_onehot=u_onehot,
                       padded=padded,
                       terminated=terminate,
                       p_state=p_s
                       )
        # add episode dim
        for key in episode.keys():
            episode[key] = np.array([episode[key]])
        if not evaluate:
            self.epsilon = epsilon
        if args.alg == 'maven':
            episode['z'] = np.array([maven_z.copy()])
        if evaluate and episode_num == args.evaluate_epoch - 1 and args.replay_dir != '':
            self.env.save_replay()
            self.env.close()
        return episode, episode_reward, win_tag, step

# RolloutWorker for communication
class CommRolloutWorker:
    """Roll out one episode at a time for agents that share action weights.

    Per step ``agents.get_action_weights`` is called once on all
    observations, then ``agents.choose_action`` is called once per agent with
    that agent's weights and availability mask.
    """

    def __init__(self, env, agents, args):
        """Store the environment, the agents and the rollout settings.

        Args:
            env: environment providing ``reset``, ``step``, ``get_obs``,
                ``get_state``, ``get_avail_agent_actions``, ``save_replay``
                and ``close``.
            agents: object providing ``get_action_weights``,
                ``choose_action`` and a ``policy`` with ``init_hidden``.
            args: settings namespace; the attributes copied below plus
                ``replay_dir``, ``epsilon_anneal_scale`` and
                ``evaluate_epoch`` are read by this class.
        """
        self.env = env
        self.agents = agents
        self.episode_limit = args.episode_limit
        self.n_actions = args.n_actions
        self.n_agents = args.n_agents
        self.state_shape = args.state_shape
        self.obs_shape = args.obs_shape
        self.args = args

        self.epsilon = args.epsilon
        self.anneal_epsilon = args.anneal_epsilon
        self.min_epsilon = args.min_epsilon
        print('Init CommRolloutWorker')

    def _anneal(self, epsilon):
        """Return ``epsilon`` lowered by one anneal step while above the floor."""
        if epsilon > self.min_epsilon:
            return epsilon - self.anneal_epsilon
        return epsilon

    def generate_episode(self, episode_num=None, evaluate=False):
        """Run one episode and return it padded to ``episode_limit`` steps.

        Args:
            episode_num: index of the evaluation episode; only used to decide
                when to close the environment / save a replay.
            evaluate: when true, epsilon is 0 for this episode and
                ``self.epsilon`` is left untouched.

        Returns:
            ``(episode, episode_reward, win_tag, step)`` where ``episode`` is
            a dict of arrays, each with a leading axis of length 1, with keys
            ``o, s, u, r, avail_u, o_next, s_next, avail_u_next, u_onehot,
            padded, terminated``; ``episode_reward`` is the summed reward,
            ``win_tag`` tells whether the last executed step terminated with
            ``info['battle_won']`` set, and ``step`` is the number of
            executed steps.
        """
        args = self.args
        # prepare for save replay
        if args.replay_dir != '' and evaluate and episode_num == 0:
            self.env.close()

        o, u, r, s, avail_u, u_onehot = [], [], [], [], [], []
        self.env.reset()
        terminated = False
        win_tag = False
        step = 0
        episode_reward = 0
        last_action = np.zeros((args.n_agents, args.n_actions))
        self.agents.policy.init_hidden(1)

        epsilon = 0 if evaluate else self.epsilon
        if args.epsilon_anneal_scale == 'episode':
            epsilon = self._anneal(epsilon)

        while not terminated and step < self.episode_limit:
            obs = self.env.get_obs()
            state = self.env.get_state()
            actions, avail_actions, actions_onehot = [], [], []

            # get the weights of all actions for all agents
            weights = self.agents.get_action_weights(np.array(obs), last_action)

            # choose action for each agent
            for agent_id in range(self.n_agents):
                avail_action = self.env.get_avail_agent_actions(agent_id)
                action = self.agents.choose_action(weights[agent_id], avail_action, epsilon, evaluate)

                # generate the one-hot vector of the action
                action_onehot = np.zeros(args.n_actions)
                action_onehot[action] = 1
                # Builtin int: the np.int alias no longer exists in current NumPy.
                actions.append(int(action))
                actions_onehot.append(action_onehot)
                avail_actions.append(avail_action)
                last_action[agent_id] = action_onehot

            reward, terminated, info = self.env.step(actions)
            win_tag = True if terminated and 'battle_won' in info and info['battle_won'] else False
            o.append(obs)
            s.append(state)
            u.append(np.reshape(actions, [self.n_agents, 1]))
            u_onehot.append(actions_onehot)
            avail_u.append(avail_actions)
            r.append([reward])

            episode_reward += reward
            step += 1
            if args.epsilon_anneal_scale == 'step':
                epsilon = self._anneal(epsilon)

        # last obs
        obs = self.env.get_obs()
        state = self.env.get_state()
        o.append(obs)
        s.append(state)
        o_next, s_next = o[1:], s[1:]
        o, s = o[:-1], s[:-1]
        # get avail_action for the last obs, because target_q needs it in training
        avail_u.append([self.env.get_avail_agent_actions(agent_id)
                        for agent_id in range(self.n_agents)])
        avail_u_next = avail_u[1:]
        avail_u = avail_u[:-1]

        # Pad every sequence up to episode_limit entries.
        remain_length = self.episode_limit - step
        obs_pad = [np.zeros((self.n_agents, self.obs_shape))] * remain_length
        state_pad = [np.zeros((self.state_shape))] * remain_length
        avail_pad = [np.ones((self.n_agents, self.n_actions))] * remain_length
        o = o + obs_pad
        o_next = o_next + obs_pad
        u = u + [np.zeros((self.n_agents, 1))] * remain_length
        u_onehot = u_onehot + [np.zeros((self.n_agents, self.n_actions))] * remain_length
        s = s + state_pad
        s_next = s_next + state_pad
        r = r + [[0]] * remain_length
        avail_u = avail_u + avail_pad
        avail_u_next = avail_u_next + avail_pad
        padded = [[0.]] * step + [[1.]] * remain_length
        # Previously this list was never filled, so 'terminated' had no time
        # axis. It now follows the convention of the other rollout workers in
        # this file: 0 for executed steps, 1 for padding.
        # NOTE(review): confirm the learner does not expect the environment's
        # per-step done flag here instead.
        terminate = [[0.]] * step + [[1.]] * remain_length

        episode = dict(o=o,
                       s=s,
                       u=u,
                       r=r,
                       avail_u=avail_u,
                       o_next=o_next,
                       s_next=s_next,
                       avail_u_next=avail_u_next,
                       u_onehot=u_onehot,
                       padded=padded,
                       terminated=terminate
                       )
        # add episode dim
        for key in episode.keys():
            episode[key] = np.array([episode[key]])
        if not evaluate:
            self.epsilon = epsilon
        if evaluate and episode_num == args.evaluate_epoch - 1 and args.replay_dir != '':
            self.env.save_replay()
            self.env.close()
        return episode, episode_reward, win_tag, step
